import tilelang.language as T
import tilelang.testing
import torch
import weakref
import gc


def test_tilelang_capture():
    """Check that creating a JIT kernel does not keep a local tensor alive.

    A tensor is created in this function's scope and a weak reference to it
    is taken. The ``tilelang.jit``-decorated factory is then called while the
    tensor is still alive, although the tensor is never passed to the factory
    or to the kernel. After this function's strong reference is deleted and a
    garbage-collection pass is run, the weak reference must be dead.

    NOTE(review): presumably this guards against the JIT / prim_func
    machinery capturing the caller's local variables -- confirm against the
    change that introduced this test.

    Raises:
        AssertionError: If the tensor is still reachable after collection.
    """

    # Factory for a minimal kernel that writes 1 into element 0 of a
    # one-element float32 tensor. The kernel's parameter ``a`` is its own
    # binding and is unrelated to the local tensor ``a`` created below.
    # NOTE(review): the two pass-config keys presumably disable TMA lowering
    # and warp specialization -- confirm against the PassConfigKey docs.
    @tilelang.jit(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },)
    def get_dummy_kernel():

        @T.prim_func
        def dummy_kernel(a: T.Tensor[(1,), T.float32],):
            with T.Kernel(1) as _:
                a[0] = 1

        return dummy_kernel

    # Local tensor that must not outlive this scope; the weak reference lets
    # us observe its lifetime without extending it.
    a = torch.randn(1, 1024)
    a_weak = weakref.ref(a)
    # Call the factory while ``a`` is still a live local. The result is kept
    # in ``_kernel`` (and never invoked) so it stays alive during the check.
    _kernel = get_dummy_kernel()
    # Drop this function's strong reference, then force a collection pass so
    # that reference cycles are also cleaned up before the check.
    del a
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()
    # Dereferencing the weak reference yields None once the tensor is gone.
    a_upgrade = a_weak()
    assert a_upgrade is None, "A is not garbage collected"

    # Debugging aid (requires the third-party ``objgraph`` package to be
    # imported): uncomment and run before the assert above to see what is
    # still holding the tensor.
    # if a_upgrade is not None:
    #     objgraph.show_backrefs([a_upgrade], max_depth=5)


# When executed directly (not imported), run tilelang.testing.main().
if __name__ == "__main__":
    tilelang.testing.main()
